
-- set search_path to sm_sc;
-- drop function if exists sm_sc.__fv_d_redistr_softmax_1d_dloss_dindepdt(float[], float[], float[]);
-- Backward pass of a 1-D softmax: derives dloss/dindepdt (gradient w.r.t. the
-- softmax input) from the softmax output and the upstream gradient dloss/ddepdt.
--
-- The result is  y *` (g -` (|@+| (y *` g)))  with y = softmax output and
-- g = dloss/ddepdt.  NOTE(review): this reads as the vector-Jacobian product of
-- softmax, assuming *` and -` are element-wise and |@+| sums all elements; those
-- operators are defined elsewhere in sm_sc -- confirm against their definitions.
--
-- An earlier implementation materialised the full matrix
-- (fv_eye(...) -` (y |**| y^T)) *` g and then slice-summed it.  It was dropped
-- to avoid 'ERROR: array size exceeds the maximum allowed' on long vectors.
create or replace function sm_sc.__fv_d_redistr_softmax_1d_dloss_dindepdt
(
  i_depdt                  float[]                 ,                     -- softmax output; when null it is recomputed from i_indepdt
  i_dloss_ddepdt           float[]                 ,                     -- upstream gradient dloss/ddepdt, from which dloss/dindepdt is derived
  i_indepdt                float[]                                       -- softmax input (output of the previous operator); read only by the debug check and when i_depdt is null
)
returns float[]     -- dloss/dindepdt; per the original note, column order follows the one-hot order of the i_indepdt categories
as
$$
declare
  v_depdt             float[];     -- softmax output actually used: defaulted when missing, then clamped away from zero

begin
  -- Optional shape audit, enabled by the custom setting pg4ml._v_is_debug_check = '1'
  -- (missing_ok = true, so an undefined setting simply disables the audit).
  if current_setting('pg4ml._v_is_debug_check', true) = '1'
  then
    -- i_indepdt and i_depdt must be 2-D with at least one dimension of length 1,
    -- and the gradient must share its dims with i_indepdt or with i_depdt.
    -- A null argument makes its own comparisons NULL, i.e. it is not rejected here.
    -- i_depdt may legitimately be null at this point (it is defaulted below); in
    -- that case the gradient is compared against i_indepdt only.  Without the
    -- coalesce, the NULL comparison would neutralise the whole "and" clause and a
    -- mismatched gradient would pass unnoticed.
    -- NOTE(review): assumes fv_redistr_softmax keeps the dims of its input -- confirm.
    if array_ndims(i_indepdt) <> 2
      or (array_length(i_indepdt, 1) <> 1 and array_length(i_indepdt, 2) <> 1)
      or array_ndims(i_depdt) <> 2
      or (array_length(i_depdt, 1) <> 1 and array_length(i_depdt, 2) <> 1)
      or (
           array_dims(i_dloss_ddepdt) <> array_dims(i_indepdt)
           and array_dims(i_dloss_ddepdt) <> coalesce(array_dims(i_depdt), array_dims(i_indepdt))
         )
    then
      raise exception 'array_ndims and (array_length(, 1) or array_length(, 2)) should be 2 and 1';
    end if;
  end if;

  -- Use the caller-supplied softmax output when available, otherwise recompute it.
  if i_depdt is null
  then
    v_depdt := sm_sc.fv_redistr_softmax(i_indepdt);
  else
    v_depdt := i_depdt;
  end if;

  -- A predicted probability of exactly zero would break the softmax derivative,
  -- so values are clamped to a tiny floor of exp(-200).
  -- NOTE(review): presumably @>` is an element-wise "greatest of" -- confirm.
  v_depdt := v_depdt @>` (exp(-2.0e2) :: float);

  return
    v_depdt *` (i_dloss_ddepdt -` (|@+| (v_depdt *` i_dloss_ddepdt)))
  ;

end
$$
language plpgsql stable
parallel safe
cost 100;
-- -- set search_path to sm_sc;
-- select sm_sc.__fv_d_redistr_softmax_1d_dloss_dindepdt
--   (
--     sm_sc.fv_redistr_softmax(array[array[1, 2, 3, 4, 5]] :: float[]),
--     -` array[array[0.0 :: float, 0.0 :: float, 0.0 :: float, 0.0 :: float, 1.0]] /` sm_sc.fv_redistr_softmax(array[array[1, 2, 3, 4, 5]] :: float[]),
--     array[array[1, 2, 3, 4, 5]]
--   );

-- select sm_sc.__fv_d_redistr_softmax_1d_dloss_dindepdt
--   (
--     sm_sc.fv_redistr_softmax(array[array[1], array[2], array[3], array[4], array[5]] :: float[]),
--     -` array[array[0.0 :: float], array[0.0 :: float], array[0.0 :: float], array[0.0 :: float], array[1.0 :: float]] /` sm_sc.fv_redistr_softmax(array[array[1], array[2], array[3], array[4], array[5]] :: float[]),
--     array[array[1], array[2], array[3], array[4], array[5]]
--   );